GRU
将多层门控循环单元 (GRU) RNN 应用于输入序列。
GRU 网络模型中有两个门:更新门和重置门。将两个连续的时间节点表示为 \(t - 1\) 和 \(t\)。给定一个在时刻 \(t\) 的输入 \(x_t\),一个隐藏状态 \(h_{t-1}\),在时刻 \(t\) 的更新门和重置门使用门控制机制计算。更新门 \(z_t\) 用于控制前一时刻的状态信息被带入到当前状态中的程度,重置门 \(r_t\) 控制前一状态有多少信息被写入到当前候选集 \(n_t\) 上
对于输入序列中的每个元素,每一层计算以下函数:
\[
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
\]
其中 \(\sigma\) 是 sigmoid 激活函数,\(\odot\) 是 Hadamard 积(逐元素乘积)。\(W, b\) 是公式中输出和输入之间的可学习权重。例如,\(W_{ir}, b_{ir}\) 是用于将输入 \(x_t\) 转换为 \(r_t\) 的权重和偏置。
注意,本算子中候选门 \(n_t\) 的计算与原始论文和 Mindspore 框架略有不同。在原始实现中,\(r_t\) 和上一隐藏状态 \(h_{t-1}\) 之间的 Hadamard 积 (\(\odot\)) 在与权重矩阵 \(W\) 相乘和加上偏置之前进行:
\[ n_t = \tanh(W_{in} x_t + b_{in} + W_{hn}(r_t \odot h_{t-1}) + b_{hn}) \]
本算子采用 PyTorch 实现方式,Hadamard 积是在计算 \(W_{hn} h_{t-1}\) 之后完成的:
\[ n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})) \]
- 输入:
input - 输入数据的地址。
weight_g - 可学习的输入-隐藏权重的地址。
weight_r - 可学习的隐藏-隐藏权重的地址。
input_bias - 可学习的输入-隐藏偏置的地址。
state_bias - 可学习的隐藏-隐藏偏置的地址。
hidden_state - 初始隐藏状态的地址。
buffer - 用于存储中间计算结果。
gru_param - 算子计算所需参数的结构体。其各成员见下述。
core_mask - 核掩码。
GruParameter定义:
/* Parameters required by the GRU operator computation. */
typedef struct GruParameter {
  int input_size_;       // number of expected features in the input x
  int hidden_size_;      // number of features in the hidden state h
  int seq_len_;          // length of each sequence in the input batch
  int batch_;            // total number of batches
  int output_step_;      // output stride per loop iteration
  int bidirectional_;    // whether the GRU is bidirectional
  int input_row_align_;  // input row alignment value
  int input_col_align_;  // input column alignment value
  int state_row_align_;  // hidden-state row alignment value
  int state_col_align_;  // hidden-state column alignment value
  int check_seq_len_;    // sequence length actually computed
} GruParameter;
- 输出:
output - 输出地址。
hidden_state - 最终的隐藏状态。
- 支持平台:
FT78NE, MT7004
备注
FT78NE 支持int8, fp32
MT7004 支持fp16, fp32
共享存储版本:
-
void i8_Gru_s(int8_t *output, int8_t *input, int8_t *weight_g, int8_t *weight_r, int8_t *input_bias, int8_t *state_bias, int8_t *hidden_state, int8_t *buffer[4], GruParameter *gru_param, int core_mask);
-
void hp_Gru_s(half *output, half *input, half *weight_g, half *weight_r, half *input_bias, half *state_bias, half *hidden_state, half *buffer[4], GruParameter *gru_param, int core_mask);
-
void fp_Gru_s(float *output, float *input, float *weight_g, float *weight_r, float *input_bias, float *state_bias, float *hidden_state, float *buffer[4], GruParameter *gru_param, int core_mask);
C调用示例:
1void TestGruSMCFp32(int check_seq_len, int seq_len, int batch_size, int input_size, int bidirectional, int hidden_size, int core_mask) {
2 int core_id = get_core_id();
3 int logic_core_id = GetLogicCoreId(core_mask, core_id);
4 int core_num = GetCoreNum(core_mask);
5 float *output = (void*)0x88000000;
6 float *input = (void*)0x88100000;
7 float *weight_g = (void*)0x88200000;
8 float *weight_r = (void*)0x88300000;
9 float *input_bias = (void*)0x88400000;
10 float *state_bias = (void*)0x88500000;
11 float *hidden_state = (void*)0x88600000;
12 float** buffer = (float**)0x88700000;
13 float *output_hidden_state = (void*)0x88800000;
14 GruParameter* param = (GruParameter*)0x88900000;
15 int hidden_state_batch = 1;
16 int num_directions = 1;
17 if (bidirectional) {
18 hidden_state_batch = hidden_state_batch * 2;
19 num_directions = num_directions * 2;
20 }
21 int input_col_align = hidden_size;
22 int state_col_align = hidden_size;
23 if (logic_core_id == 0) {
24 memcpy(output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
25 memcpy(check_output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
26 buffer[0] = (void*)0x88A00000;
27 buffer[1] = (void*)0x88B00000;
28 buffer[2] = (void*)0x88C00000;
29 buffer[3] = (void*)0x88D00000;
30 param->batch_ = batch_size;
31 param->bidirectional_ = bidirectional;
32 param->hidden_size_ = hidden_size;
33 param->input_col_align_ = input_col_align;
34 param->input_size_ = input_size;
35 param->output_step_ = batch_size * hidden_size * num_directions;
36 param->seq_len_ = seq_len;
37 param->state_col_align_ = state_col_align;
38 param->check_seq_len_ = check_seq_len;
39 }
40 sys_bar(0, core_num); // 初始化参数完成后进行同步
41 fp_Gru_s(output, input, weight_g, weight_r, input_bias, state_bias, output_hidden_state, buffer, param, core_mask);
42}
43
/* Entry point: launch the shared-memory GRU example on 4 cores with a tiny shape. */
int main(void) {
  int check_seq_len = 2;
  int seq_len = 2;
  int batch_size = 2;
  int input_size = 2;
  int bidirectional = 0;
  int hidden_size = 2;
  int core_mask = 0b1111;
  TestGruSMCFp32(check_seq_len, seq_len, batch_size, input_size, bidirectional, hidden_size, core_mask);
  return 0;
}
私有存储版本:
-
void i8_Gru_p(int8_t *output, int8_t *input, int8_t *weight_g, int8_t *weight_r, int8_t *input_bias, int8_t *state_bias, int8_t *hidden_state, int8_t *buffer[4], GruParameter *gru_param, int core_mask);
-
void hp_Gru_p(half *output, half *input, half *weight_g, half *weight_r, half *input_bias, half *state_bias, half *hidden_state, half *buffer[4], GruParameter *gru_param, int core_mask);
-
void fp_Gru_p(float *output, float *input, float *weight_g, float *weight_r, float *input_bias, float *state_bias, float *hidden_state, float *buffer[4], GruParameter *gru_param, int core_mask);
C调用示例:
1void TestGruL2Fp32(int check_seq_len, int seq_len, int batch_size, int input_size, int bidirectional, int hidden_size, int core_mask) {
2 float *output = (void*)0x10000000; // 私有存储版本地址设置在AM内
3 float *input = (void*)0x10004000;
4 float *weight_g = (void*)0x10008000;
5 float *weight_r = (void*)0x1000C000;
6 float *input_bias = (void*)0x10010000;
7 float *state_bias = (void*)0x10014000;
8 float *hidden_state = (void*)0x10018000;
9 float** buffer = (float**)0x1001C000;
10 float *output_hidden_state = (void*)0x10020000;
11 GruParameter* param = (GruParameter*)0x10024000;
12 int hidden_state_batch = 1;
13 int num_directions = 1;
14 if (bidirectional) {
15 hidden_state_batch = hidden_state_batch * 2;
16 num_directions = num_directions * 2;
17 }
18 int input_col_align = hidden_size;
19 int state_col_align = hidden_size;
20 memcpy(output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
21 memcpy(check_output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
22 buffer[0] = (void*)0x10030000;
23 buffer[1] = (void*)0x10034000;
24 buffer[2] = (void*)0x10038000;
25 buffer[3] = (void*)0x1003C000;
26 param->batch_ = batch_size;
27 param->bidirectional_ = bidirectional;
28 param->hidden_size_ = hidden_size;
29 param->input_col_align_ = input_col_align;
30 param->input_size_ = input_size;
31 param->output_step_ = batch_size * hidden_size * num_directions;
32 param->seq_len_ = seq_len;
33 param->state_col_align_ = state_col_align;
34 param->check_seq_len_ = check_seq_len;
35 fp_Gru_p(output, input, weight_g, weight_r, input_bias, state_bias, output_hidden_state, buffer, param, core_mask);
36}
37
/* Entry point: launch the private-memory GRU example on a single core.
 * Note: was `void main()` yet returned 0 — corrected to standard `int main`. */
int main(void) {
  int check_seq_len = 2;
  int seq_len = 2;
  int batch_size = 2;
  int input_size = 2;
  int bidirectional = 0;
  int hidden_size = 2;
  int core_mask = 0b0001; /* the private-storage version must be started on exactly one core */
  TestGruL2Fp32(check_seq_len, seq_len, batch_size, input_size, bidirectional, hidden_size, core_mask);
  return 0;
}